from __future__ import annotations

import atexit
import os
import sys
import tarfile
import tempfile
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, ExitStack, contextmanager
from os.path import exists
from threading import Lock
from typing import TypeVar
from urllib.request import urlopen

if sys.version_info >= (3, 9):
    from importlib.resources import as_file, files
else:
    from importlib_resources import as_file, files

try:
    from .version import __version__  # NOQA
except ImportError:
    raise ImportError("BUG: version.py doesn't exist. Please file a bug report.")

from .htsengine import HTSEngine
from .openjtalk import OpenJTalk
from .openjtalk import mecab_dict_index as _mecab_dict_index
from .utils import merge_njd_marine_features

# Holds the contexts entered via importlib.resources.as_file so that any
# materialized resource paths stay valid until interpreter exit.
_file_manager = ExitStack()
atexit.register(_file_manager.close)

_pyopenjtalk_ref = files(__name__)
_dic_dir_name = "open_jtalk_dic_utf_8-1.11"

# Dictionary directory
# Taken from the OPEN_JTALK_DICT_DIR environment variable when it is set;
# otherwise defaults to the package directory, where the dictionary will be
# automatically downloaded on first use. The package resource is only resolved
# when the variable is absent, so an explicit override never has to touch (or
# materialize) package data.
_dict_dir = os.environ.get("OPEN_JTALK_DICT_DIR")
if _dict_dir is None:
    _dict_dir = str(
        _file_manager.enter_context(as_file(_pyopenjtalk_ref / _dic_dir_name))
    )
OPEN_JTALK_DICT_DIR = _dict_dir.encode("utf-8")
del _dict_dir

_dict_download_url = "https://github.com/r9y9/open_jtalk/releases/download/v1.11.1"
# The archive is named after the directory it unpacks to.
_DICT_URL = f"{_dict_download_url}/{_dic_dir_name}.tar.gz"

# Default mei_normal.voice for HMM-based TTS
DEFAULT_HTS_VOICE = str(
    _file_manager.enter_context(
        as_file(_pyopenjtalk_ref / "htsvoice/mei_normal.htsvoice")
    )
).encode("utf-8")


def _extract_dic():
    """Download the Open JTalk dictionary and unpack it into the package directory.

    The archive at ``_DICT_URL`` is streamed into a temporary file (with a tqdm
    progress bar), then extracted into the package directory. Afterwards the
    module-level ``OPEN_JTALK_DICT_DIR`` is repointed at the extracted
    ``_dic_dir_name`` directory inside the package.
    """
    from tqdm.auto import tqdm

    global OPEN_JTALK_DICT_DIR
    chunk_size = 1 << 16
    pyopenjtalk_dir = _file_manager.enter_context(as_file(_pyopenjtalk_ref))
    with tempfile.TemporaryFile() as t:
        print('Downloading: "{}"'.format(_DICT_URL))
        with urlopen(_DICT_URL) as response:
            with tqdm.wrapattr(
                t, "write", total=getattr(response, "length", None)
            ) as tar:
                # Read fixed-size blocks. Iterating the response object would
                # split the binary archive on b"\n" (line iteration), producing
                # arbitrarily sized chunks.
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    tar.write(chunk)
        t.seek(0)
        print("Extracting tar file")
        with tarfile.open(mode="r|gz", fileobj=t) as f:
            # Use the "data" extraction filter when this Python provides it
            # (3.12+ and security backports). It rejects absolute paths, ".."
            # traversal and links escaping the destination directory.
            if hasattr(tarfile, "data_filter"):
                f.extractall(path=pyopenjtalk_dir, filter="data")
            else:
                f.extractall(path=pyopenjtalk_dir)
    OPEN_JTALK_DICT_DIR = str(pyopenjtalk_dir / _dic_dir_name).encode("utf-8")


def _lazy_init():
    """Download and extract the dictionary unless OPEN_JTALK_DICT_DIR already exists."""
    if exists(OPEN_JTALK_DICT_DIR):
        return
    _extract_dic()


_T = TypeVar("_T")


def _global_instance_manager(
    instance_factory: Callable[[], _T] | None = None, instance: _T | None = None
) -> Callable[[], Generator[_T, None, None]]:
    assert instance_factory is not None or instance is not None
    _instance = instance
    mutex = Lock()

    @contextmanager
    def manager() -> Generator[_T, None, None]:
        nonlocal _instance
        with mutex:
            if _instance is None:
                _instance = instance_factory()
            yield _instance

    return manager


def _jtalk_factory() -> OpenJTalk:
    """Build the shared OpenJTalk frontend, fetching the dictionary first if absent."""
    _lazy_init()
    # Read after _lazy_init(), which may repoint the dictionary directory.
    dictionary_dir = OPEN_JTALK_DICT_DIR
    return OpenJTalk(dn_mecab=dictionary_dir)


def _marine_factory():
    try:
        from marine.predict import Predictor
    except ImportError:
        raise ImportError("Please install marine by `pip install pyopenjtalk[marine]`")
    return Predictor()


def _htsengine_factory() -> HTSEngine:
    """Build the shared HTS engine loaded with the default mei_normal voice."""
    # DEFAULT_HTS_VOICE is looked up when the engine is first needed.
    return HTSEngine(DEFAULT_HTS_VOICE)


# Process-wide instances, each created lazily on first use and guarded by its
# own manager.
_global_jtalk = _global_instance_manager(_jtalk_factory)  # OpenJTalk
_global_htsengine = _global_instance_manager(_htsengine_factory)  # HTSEngine
_global_marine = _global_instance_manager(_marine_factory)  # marine


def g2p(*args, **kwargs):
    """Grapheme-to-phoneme (G2P) conversion

    This is just a convenient wrapper around `run_frontend`. All arguments are
    forwarded unchanged to the global OpenJTalk instance's ``g2p``.

    Args:
        text (str): Unicode Japanese text.
        kana (bool): If True, returns the pronunciation in katakana, otherwise in phones.
          Default is False.
        join (bool): If True, concatenate phones or katakana into a single string.
          Default is True.

    Returns:
        str or list: G2P result in 1) str if join is True 2) list if join is False.
    """
    with _global_jtalk() as jtalk:
        return jtalk.g2p(*args, **kwargs)


def estimate_accent(njd_features):
    """Accent estimation using marine

    This function requires marine (https://github.com/6gsn/marine)

    Args:
        njd_features (list): features generated by OpenJTalk (e.g. the output
          of `run_frontend`).

    Returns:
        list: features for NJDNode with estimation results by marine.

    Raises:
        ImportError: if marine is not installed.
    """
    # Acquire the predictor first: if marine is missing, its factory raises an
    # ImportError with installation instructions before the import below runs.
    with _global_marine() as marine:
        from marine.utils.openjtalk_util import convert_njd_feature_to_marine_feature

        marine_feature = convert_njd_feature_to_marine_feature(njd_features)
        marine_results = marine.predict(
            [marine_feature], require_open_jtalk_format=True
        )
    njd_features = merge_njd_marine_features(njd_features, marine_results)
    return njd_features


def extract_fullcontext(text, run_marine=False):
    """Convert text into full-context labels.

    The text goes through the OpenJTalk frontend. Optionally, the accents are
    re-estimated with marine. Then labels are built from the resulting
    features.

    Args:
        text (str): Input text
        run_marine (bool): Whether to estimate accent using marine.
          Default is False. Enabling it requires marine, installable with
          `pip install pyopenjtalk[marine]`

    Returns:
        list: List of full-context labels
    """
    features = run_frontend(text)
    features = estimate_accent(features) if run_marine else features
    return make_label(features)


def synthesize(labels, speed=1.0, half_tone=0.0):
    """Run OpenJTalk's speech synthesis backend

    Args:
        labels (list or tuple): Full-context labels. A 2-tuple is also accepted,
          in which case its second element is used as the labels.
        speed (float): speech speed rate. Default is 1.0.
        half_tone (float): additional half-tone. Default is 0.

    Returns:
        np.ndarray: speech waveform (dtype: np.float64)
        int: sampling frequency (default: 48000)
    """
    # Accept a (something, labels) pair by taking its second element.
    if isinstance(labels, tuple) and len(labels) == 2:
        labels = labels[1]

    with _global_htsengine() as htsengine:
        sr = htsengine.get_sampling_frequency()
        # The engine is shared, so set the parameters on every call.
        htsengine.set_speed(speed)
        htsengine.add_half_tone(half_tone)
        return htsengine.synthesize(labels), sr


def tts(text, speed=1.0, half_tone=0.0, run_marine=False):
    """Text-to-speech

    Equivalent to ``synthesize(extract_fullcontext(text, run_marine), speed, half_tone)``.

    Args:
        text (str): Input text
        speed (float): speech speed rate. Default is 1.0.
        half_tone (float): additional half-tone. Default is 0.
        run_marine (bool): Whether to estimate accent using marine.
          Default is False. If you want to activate this option, you need to install marine
          by `pip install pyopenjtalk[marine]`

    Returns:
        np.ndarray: speech waveform (dtype: np.float64)
        int: sampling frequency (default: 48000)
    """
    return synthesize(
        extract_fullcontext(text, run_marine=run_marine), speed, half_tone
    )


def run_frontend(text):
    """Analyze text with OpenJTalk's frontend (text processing).

    Args:
        text (str): Unicode Japanese text.

    Returns:
        list: features for NJDNode.
    """
    with _global_jtalk() as engine:
        features = engine.run_frontend(text)
    return features


def make_label(njd_features):
    """Build full-context labels from NJD features using the global OpenJTalk.

    Args:
        njd_features (list): features for NJDNode.

    Returns:
        list: full-context labels.
    """
    with _global_jtalk() as engine:
        labels = engine.make_label(njd_features)
    return labels


def mecab_dict_index(path, out_path, dn_mecab=None):
    """Create user dictionary

    Compiles a user dictionary CSV into a MeCab user dictionary file.

    Args:
        path (str, bytes or os.PathLike): path to user csv
        out_path (str, bytes or os.PathLike): path to output dictionary
        dn_mecab (optional. str, bytes or os.PathLike): path to mecab dictionary.
          Defaults to the Open JTalk system dictionary, which is downloaded
          first if it is not present yet.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        RuntimeError: if building the dictionary fails.
    """

    def _encode(p):
        # The native helper is called with byte-string paths; text paths are
        # encoded as UTF-8, matching how the other paths here are encoded.
        p = os.fspath(p)
        return p if isinstance(p, bytes) else p.encode("utf-8")

    if not exists(path):
        raise FileNotFoundError(f"no such file or directory: {path}")
    if dn_mecab is None:
        with _global_jtalk():  # call _lazy_init()
            pass
        dn_mecab = OPEN_JTALK_DICT_DIR
    r = _mecab_dict_index(_encode(dn_mecab), _encode(path), _encode(out_path))

    # NOTE: mecab load returns 1 if success, but mecab_dict_index return the opposite
    # yeah it's confusing...
    if r != 0:
        raise RuntimeError("Failed to create user dictionary")


def update_global_jtalk_with_user_dict(path):
    """Update global openjtalk instance with the user dictionary

    Note that this will change the global state of the openjtalk module: the
    global OpenJTalk instance is replaced by a new one built with the system
    dictionary plus the given user dictionary.

    Args:
        path (str, bytes or os.PathLike): path to user dictionary

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    global _global_jtalk
    # Validate before acquiring the global instance: acquiring it may first
    # download and extract the system dictionary, which is wasted work (and a
    # network access) if the user dictionary is missing anyway.
    if not exists(path):
        raise FileNotFoundError(f"no such file or directory: {path}")
    userdic = os.fspath(path)
    if not isinstance(userdic, bytes):
        userdic = userdic.encode("utf-8")
    # Holding the current instance makes sure the system dictionary is ready
    # and no other thread uses the old instance while it is being replaced.
    with _global_jtalk():
        _global_jtalk = _global_instance_manager(
            instance=OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR, userdic=userdic)
        )
